#' Estimator for the average marginal component effect (AMCE)
#'
#' Function created by Mikael Poul Johannesson.
#'
#' Estimates the average marginal component effect (AMCE) for all
#' provided treatments by regressing \code{post} on each treatment
#' separately. Standard errors are cluster-robust when \code{cluster}
#' is supplied.
#'
#' @importFrom stats as.formula lm na.omit
#' @importFrom broom tidy
#' @importFrom sandwich estfun sandwich
#' @importFrom lmtest coeftest
#' @importFrom tidyselect vars_select
#' @importFrom rlang quos is_empty enexpr quo_name
#'
#' @param data Data frame holding the outcome, the treatments and any
#'   cluster, subgroup or diff variables.
#' @param post Post treatment variable (the outcome), without quotes.
#' @param ... The treatment components (without quotes). If none are
#'   given, every column of \code{data} other than \code{post},
#'   \code{subgroup}, \code{diff} and \code{cluster} is used.
#' @param cluster Name (character string) of the variable indicating
#'   clusters if SEs should be clustered.
#' @param subgroup Name(s) (character) of the variable(s) indicating
#'   which subgroups, if any, to independently calculate AMCEs for. If
#'   multiple subgroups are provided, then it will iterate over each
#'   combination.
#' @param diff If \code{diff} (character variable name(s)) is
#'   provided, it will calculate the difference in AMCE between the
#'   unique values of \code{diff}, i.e. the treatment-by-\code{diff}
#'   interaction terms.
#' @param subset Subset to estimate AMCEs for; used as the row index
#'   in \code{data[subset, ]}.
#' @return A data frame with one row per treatment value (the
#'   reference level is included with estimate 0). Columns are
#'   \code{value}, \code{estimate}, \code{std_error},
#'   \code{statistic}, \code{p_value}, \code{treatment} and
#'   \code{value_order}, plus \code{diff_variable} and
#'   \code{diff_value} when \code{diff} is given and one column per
#'   subgroup variable when \code{subgroup} is given.
#' @keywords AMCE
#' @examples
#' \dontrun{
#'  amce <- amce(data, post, treatment_1, treatment_2)
#' }
#' @export
amce <- function(data, post, ...,
                 cluster = NULL,
                 subgroup = NULL,
                 diff = NULL,
                 subset = NULL) {

  if (!is.null(subset)) {
    data <- data[subset, , drop = FALSE]
  }

  have_diff <- !is.null(diff)
  have_subgroup <- !is.null(subgroup)

  post <- quo_name(enexpr(post))

  ## Resolve the treatment columns; default to every remaining column.
  treat_quos <- quos(...)
  if (is_empty(treat_quos)) {
    treat <- names(data)
  } else {
    treat <- unname(tidyselect::vars_select(names(data), !!!treat_quos))
  }
  treat <- setdiff(treat, c(post, subgroup, diff, cluster))

  if (length(treat) == 0L) {
    stop("No treatment variables left to estimate AMCEs for.",
         call. = FALSE)
  }

  ## Turn all vars into factors.
  for (var in c(treat, subgroup, diff)) {
    if (!is.factor(data[[var]])) {
      data[[var]] <- factor(data[[var]])
    }
  }

  ## Make compute grid to lapply over: one row per combination of
  ## subgroup values, treatment and (optionally) diff variable.
  grid_values <- list()
  if (have_subgroup) {
    grid_values <- lapply(subgroup, function(x) levels(factor(data[[x]])))
    names(grid_values) <- subgroup
  }
  grid_values <- c(grid_values, list(treat = treat))
  if (have_diff) {
    grid_values <- c(grid_values, list(diff = diff))
  }
  grid <- expand.grid(grid_values, stringsAsFactors = FALSE)

  ## seq_len() so that an empty grid yields no iterations at all.
  res <- lapply(seq_len(nrow(grid)), function(i) {

    ## `[[` rather than `$` to avoid partial matching on column names.
    .treat <- grid[["treat"]][i]
    .diff <- if (have_diff) grid[["diff"]][i] else NULL
    .data <- data

    ## Filter subgroup
    if (have_subgroup) {
      for (.sub in subgroup) {
        ## which() drops NA comparisons; plain logical indexing would
        ## insert all-NA rows and defeat the emptiness check below.
        keep <- which(.data[[.sub]] == grid[[.sub]][i])
        .data <- .data[keep, , drop = FALSE]
        if (nrow(.data) == 0) {
          conditions <- vapply(subgroup, function(x) {
            paste0(x, " == ", gsub("\n", " ", grid[[x]][i]), "\n")
          }, character(1), USE.NAMES = FALSE)
          warning(
            "No rows left in subgroup when:\n",
            paste(conditions, collapse = "")
          )
          return(NULL)
        }
      }
    }

    if (have_diff) {
      formula <- paste0(post, " ~ ",
                        .treat, " + ",
                        .diff, " + ",
                        .treat, ":", .diff)
    } else {
      formula <- paste0(post, " ~ ", .treat)
    }

    ## Estimator that can return cluster-robust se.
    fit <- estimator_regression(
      formula = formula,
      data = .data,
      cluster = cluster
    )

    ## Baseline row for the reference category (first factor level).
    baseline <- data.frame(
      term = levels(.data[[.treat]])[1],
      estimate = 0,
      std.error = 0,
      statistic = NA,
      p.value = NA,
      stringsAsFactors = FALSE
    )
    if (have_diff) {
      ## One baseline row per non-reference level of the diff variable;
      ## the single-row baseline is recycled by cbind().
      baseline <- cbind(
        baseline,
        data.frame(
          diff_variable = .diff,
          diff_value = levels(.data[[.diff]])[-1],
          stringsAsFactors = FALSE
        )
      )
    }

    ## Tidy estimate to return
    est <- fit
    if (have_diff) {
      ## Keep only the treatment:diff interaction terms and split the
      ## diff level out of the term label.
      est <- est[grepl("\\:", est$term), ]
      est$diff_variable <- .diff
      est$diff_value <- gsub("^.*\\:(.*)$", "\\1", est$term)
      est$diff_value <- gsub(paste0("^", .diff), "", est$diff_value)
    } else {
      est <- est[est$term != "(Intercept)", ]
    }

    ## Strip the interaction part and the variable-name prefix so that
    ## only the treatment level remains.
    est$term <- gsub("^(.*)\\:.*$", "\\1", est$term)
    est$term <- gsub(paste0("^", .treat), "", est$term)

    est <- rbind(baseline, est)
    est$treatment <- .treat
    names(est)[1] <- "value"

    if (have_subgroup) {
      for (.sub in subgroup) {
        est[[.sub]] <- grid[[.sub]][i]
      }
    }

    est
  })
  res <- do.call("rbind", res)

  if (is.null(res)) {
    stop("No estimates could be computed for the requested ",
         "treatments and subgroups.", call. = FALSE)
  }

  res$treatment <- factor(res$treatment, levels = unique(res$treatment))
  res$value <- factor(res$value, levels = unique(res$value))
  res$value_order <- seq_len(nrow(res))
  names(res)[names(res) == "std.error"] <- "std_error"
  names(res)[names(res) == "p.value"] <- "p_value"

  res
}



#' Fit a linear model and return tidy coefficient estimates
#'
#' Fits \code{lm()} on the complete cases of the variables involved
#' and returns the coefficient table produced by \code{tidy()}. When
#' \code{cluster} is supplied, the table is computed from
#' \code{coeftest()} with a cluster-robust covariance matrix.
#'
#' @importFrom stats as.formula lm na.omit
#' @importFrom broom tidy
#' @importFrom sandwich estfun sandwich
#' @importFrom lmtest coeftest
#'
#' @param formula Model formula, either a formula object or a
#'   character string.
#' @param data Data frame holding every variable that is referenced.
#' @param cluster Optional name (character string) of the cluster
#'   variable in \code{data}.
#' @param weights Optional name (character string) of the weights
#'   variable in \code{data}.
#' @return The tidied coefficient table.
#' @noRd
estimator_regression <- function(formula, data,
                                 cluster = NULL,
                                 weights = NULL) {

  if (is.character(formula)) formula <- as.formula(formula)

  ## Restrict to the columns in use so that na.omit() only drops rows
  ## with missing values in those columns. drop = FALSE keeps a data
  ## frame even when a single variable is involved (e.g. `y ~ 1`).
  all_vars <- unique(c(all.vars(formula), cluster, weights))
  data <- na.omit(data[, all_vars, drop = FALSE])

  ## Call lm() through do.call() so that the weights are passed by
  ## value. lm() otherwise looks the `weights` expression up in the
  ## formula's environment, which fails for formula objects that were
  ## created outside this function.
  lm_args <- list(formula = formula, data = data)
  if (!is.null(weights)) {
    lm_args$weights <- data[[weights]]
  }
  fit <- do.call("lm", lm_args)

  if (is.null(cluster)) {
    return(tidy(fit))
  }

  cl <- data[[cluster]]
  M <- length(unique(cl))
  N <- length(cl)

  if (M < 2L) {
    stop("Cluster-robust standard errors require at least two clusters.",
         call. = FALSE)
  }

  ## Finite-sample correction for the number of clusters and parameters.
  dfc <- (M / (M - 1)) * ((N - 1) / (N - fit$rank))

  ## Sum the score contributions within each cluster; rowsum() always
  ## returns a cluster-by-coefficient matrix.
  u <- rowsum(estfun(fit), group = cl)
  vcov_cl <- dfc * sandwich(fit, meat. = crossprod(u) / N)

  tidy(coeftest(fit, vcov_cl))
}